import sys

import numpy as np
import pytest
from hypothesis import given, settings, strategies

import xgboost as xgb
from xgboost import testing as tm
from xgboost.testing.data import check_inf
from xgboost.testing.data_iter import run_mixed_sparsity
from xgboost.testing.quantile_dmatrix import (
    check_categorical_strings,
    check_ref_quantile_cut,
)
from xgboost.testing.utils import predictor_equal

sys.path.append("tests/python")
import test_quantile_dmatrix as tqd


class TestQuantileDMatrix:
    cputest = tqd.TestQuantileDMatrix()

    @pytest.mark.skipif(**tm.no_cupy())
    def test_dmatrix_feature_weights(self) -> None:
        import cupy as cp

        rng = cp.random.RandomState(np.uint64(1994))
        data = rng.randn(5, 5)
        m = xgb.DMatrix(data)

        feature_weights = rng.uniform(size=5)
        m.set_info(feature_weights=feature_weights)

        cp.testing.assert_array_equal(
            cp.array(m.get_float_info("feature_weights")),
            feature_weights.astype(np.float32),
        )

    def test_categorical_strings(self) -> None:
        check_categorical_strings("cuda")

    @pytest.mark.skipif(**tm.no_cupy())
    def test_dmatrix_cupy_init(self) -> None:
        import cupy as cp

        data = cp.random.randn(5, 5)
        xgb.QuantileDMatrix(data, cp.ones(5, dtype=np.float64))

    @pytest.mark.parametrize(
        "on_device,device",
        [(True, "cpu"), (False, "cuda"), (False, "cpu"), (True, "cuda")],
    )
    def test_initialization(self, on_device: bool, device: str) -> None:
        n_samples, n_features, max_bin = 64, 3, 16
        X, y, w = tm.make_batches(
            n_samples,
            n_features=n_features,
            n_batches=1,
            use_cupy=on_device,
        )

        tree_method = "hist"
        # Init SparsePage
        Xy = xgb.DMatrix(X[0], y[0], weight=w[0])
        # Init GIDX/Ellpack
        xgb.train(
            {"tree_method": tree_method, "max_bin": max_bin, "device": device},
            Xy,
            num_boost_round=1,
        )
        # query cuts from GIDX/Ellpack
        qXy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin, ref=Xy)
        predictor_equal(Xy, qXy)
        with pytest.raises(ValueError, match="Inconsistent"):
            # max_bin changed.
            xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin - 1, ref=Xy)

        # No error, DMatrix can be modified for different training session.
        xgb.train(
            {"tree_method": tree_method, "max_bin": max_bin - 1, "device": device},
            Xy,
            num_boost_round=1,
        )

        # Init Ellpack/GIDX
        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
        # Init GIDX/Ellpack
        xgb.train(
            {"tree_method": tree_method, "max_bin": max_bin, "device": device},
            Xy,
            num_boost_round=1,
        )
        # query cuts from GIDX/Ellpack
        qXy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin, ref=Xy)
        predictor_equal(Xy, qXy)
        with pytest.raises(ValueError, match="Inconsistent"):
            # max_bin changed.
            xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin - 1, ref=Xy)

        Xy = xgb.DMatrix(X[0], y[0], weight=w[0])
        booster0 = xgb.train(
            {
                "tree_method": "hist",
                "max_bin": max_bin,
                "max_depth": 4,
                "device": "cpu",
            },
            Xy,
            num_boost_round=1,
        )
        booster1 = xgb.train(
            {
                "tree_method": "hist",
                "max_bin": max_bin,
                "max_depth": 4,
                "device": "cuda",
            },
            Xy,
            num_boost_round=1,
        )
        qXy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin, ref=Xy)
        predt0 = booster0.predict(qXy)
        predt1 = booster1.predict(qXy)
        np.testing.assert_allclose(predt0, predt1)

    @pytest.mark.skipif(**tm.no_cupy())
    @pytest.mark.parametrize(
        "device,max_bin",
        [("cpu", 16), ("cuda", 16), ("cpu", 64), ("cuda", 64)],
    )
    def test_interoperability(self, device: str, max_bin: int) -> None:
        import cupy as cp

        n_samples = 64
        n_features = 3
        X, y, w = tm.make_batches(
            n_samples, n_features=n_features, n_batches=1, use_cupy=False
        )
        # from CPU
        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
        booster_0 = xgb.train(
            {"device": device, "max_bin": max_bin}, Xy, num_boost_round=4
        )

        X[0] = cp.array(X[0])
        y[0] = cp.array(y[0])
        w[0] = cp.array(w[0])

        # from GPU
        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin)
        booster_1 = xgb.train(
            {"device": device, "max_bin": max_bin}, Xy, num_boost_round=4
        )
        cp.testing.assert_allclose(
            booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0])
        )

        with pytest.raises(ValueError, match=r"Only.*hist.*"):
            xgb.train(
                {"tree_method": "approx", "max_bin": max_bin, "device": device},
                Xy,
                num_boost_round=4,
            )

    def test_ref_quantile_cut(self) -> None:
        check_ref_quantile_cut("cuda")

    @pytest.mark.skipif(**tm.no_cupy())
    def test_metainfo(self) -> None:
        import cupy as cp

        rng = cp.random.RandomState(np.uint64(1994))

        rows = 10
        cols = 3
        data = rng.randn(rows, cols)

        labels = rng.randn(rows)

        fw = rng.randn(rows)
        fw -= fw.min()

        m = xgb.QuantileDMatrix(data=data, label=labels, feature_weights=fw)

        got_fw = m.get_float_info("feature_weights")
        got_labels = m.get_label()

        cp.testing.assert_allclose(fw, got_fw)
        cp.testing.assert_allclose(labels, got_labels)

    @pytest.mark.skipif(**tm.no_cupy())
    @pytest.mark.skipif(**tm.no_cudf())
    def test_ref_dmatrix(self) -> None:
        import cupy as cp

        rng = cp.random.RandomState(np.uint64(1994))
        self.cputest.run_ref_dmatrix(rng, "cuda", False)

    @given(
        strategies.integers(1, 1000),
        strategies.integers(1, 100),
        strategies.fractions(0, 0.99),
    )
    @settings(print_blob=True, deadline=None)
    def test_to_csr(self, n_samples: int, n_features: int, sparsity: float) -> None:
        import cupy as cp

        X, y = tm.make_sparse_regression(n_samples, n_features, sparsity, False)
        h_X = X.astype(np.float32)

        csr = h_X
        h_X = X.toarray()
        h_X[h_X == 0] = np.nan

        h_m = xgb.QuantileDMatrix(data=h_X)
        h_ret = h_m.get_data()

        d_X = cp.array(h_X)

        d_m = xgb.QuantileDMatrix(data=d_X, label=y)
        d_ret = d_m.get_data()

        np.testing.assert_equal(csr.indptr, d_ret.indptr)
        np.testing.assert_equal(csr.indices, d_ret.indices)

        np.testing.assert_equal(h_ret.indptr, d_ret.indptr)
        np.testing.assert_equal(h_ret.indices, d_ret.indices)

        booster = xgb.train({"tree_method": "hist", "device": "cuda:0"}, dtrain=d_m)

        np.testing.assert_allclose(
            booster.predict(d_m),
            booster.predict(xgb.DMatrix(d_m.get_data())),
            atol=1e-6,
        )

    def test_ltr(self) -> None:
        import cupy as cp

        X, y, qid, w = tm.make_ltr(100, 3, 3, 5)
        # make sure GPU is used to run sketching.
        cpX = cp.array(X)
        Xy_qdm = xgb.QuantileDMatrix(cpX, y, qid=qid, weight=w)
        Xy = xgb.DMatrix(X, y, qid=qid, weight=w)
        xgb.train({"device": "cuda", "objective": "rank:ndcg"}, Xy)

        from_dm = xgb.QuantileDMatrix(X, weight=w, ref=Xy)
        from_qdm = xgb.QuantileDMatrix(X, weight=w, ref=Xy_qdm)

        assert predictor_equal(from_qdm, from_dm)

    @pytest.mark.skipif(**tm.no_cupy())
    def test_check_inf(self) -> None:
        import cupy as cp

        rng = cp.random.default_rng(1994)
        check_inf(rng)

    def test_mixed_sparsity(self) -> None:
        run_mixed_sparsity("cuda")
